import os

from numpy import arctan, array, cos, exp, log, sin

from lmfit import Parameters

# Location of this module, and of the NIST StRD ``.dat`` files, which live in
# a ``NIST_STRD`` directory that is a sibling of this module's directory.
thisdir = os.path.dirname(__file__)
thisfile = os.path.basename(__file__)
NIST_DIR = os.path.join(thisdir, os.pardir, 'NIST_STRD')


def read_params(params):
    """Normalize model parameters to an indexable sequence of values.

    An lmfit ``Parameters`` object is converted to a list of each
    parameter's ``.value``, in the order ``params.values()`` yields them.
    Anything else (a list, tuple, or array of numbers) is passed through
    unchanged.
    """
    if not isinstance(params, Parameters):
        return params
    return [param.value for param in params.values()]


def Bennet5(b, x, y=0):
    """Residual of the NIST Bennett5 model ``b1 * (b2 + x)**(-1/b3)``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of three values
    (see ``read_params``).  Returns ``model - y``, the same sign convention
    as every other model function here.  With the default ``y=0``, the
    model value itself is returned; previously the model was returned
    negated.  The sum of squared residuals is unaffected by the sign.
    """
    b = read_params(b)
    model = b[0] * (b[1]+x)**(-1/b[2])
    return model - y


def BoxBOD(b, x, y=0):
    """Residual of the NIST BoxBOD model ``b1 * (1 - exp(-b2*x))``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y`` (the bare model when ``y`` is
    left at 0).
    """
    vals = read_params(b)
    asymptote, rate = vals[0], vals[1]
    return asymptote*(1 - exp(-rate*x)) - y


def Chwirut(b, x, y=0):
    """Residual of the NIST Chwirut model ``exp(-b1*x) / (b2 + b3*x)``.

    Shared by the Chwirut1 and Chwirut2 datasets.  ``b`` is an lmfit
    ``Parameters`` object or a sequence of values (see ``read_params``).
    Returns ``model - y``.
    """
    vals = read_params(b)
    numerator = exp(-vals[0]*x)
    denominator = vals[1] + vals[2]*x
    return numerator/denominator - y


def DanWood(b, x, y=0):
    """Residual of the NIST DanWood power-law model ``b1 * x**b2``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    coef, power = vals[0], vals[1]
    return coef*x**power - y


def ENSO(b, x, y=0):
    """Residual of the NIST ENSO model: an offset plus three cosine/sine pairs.

    One pair has a fixed period of 12; the other two have periods ``b4``
    and ``b7`` (``b[3]`` and ``b[6]`` here) that are fitted.  ``b`` is an
    lmfit ``Parameters`` object or a sequence of nine values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    pi = 3.141592653589793238462643383279
    phase = 2*pi*x
    terms = (vals[1]*cos(phase/12), vals[2]*sin(phase/12),
             vals[4]*cos(phase/vals[3]), vals[5]*sin(phase/vals[3]),
             vals[7]*cos(phase/vals[6]), vals[8]*sin(phase/vals[6]))
    # Accumulate left-to-right so the floating-point summation order
    # matches a single chained ``a + b + c + ...`` expression.
    periodic = terms[0]
    for term in terms[1:]:
        periodic = periodic + term
    return vals[0] + periodic - y


def Eckerle4(b, x, y=0):
    """Residual of the NIST Eckerle4 Gaussian-peak model.

    The model is ``(b1/b2) * exp(-0.5*((x - b3)/b2)**2)``.  ``b`` is an
    lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    amplitude, width, center = vals[0], vals[1], vals[2]
    scaled = (x - center)/width
    return (amplitude/width) * exp(-0.5*scaled**2) - y


def Gauss(b, x, y=0):
    """Residual of the NIST Gauss model: exponential decay plus two peaks.

    The model is ``b1*exp(-b2*x) + b3*exp(-(x-b4)**2/b5**2)
    + b6*exp(-(x-b7)**2/b8**2)``.  Shared by Gauss1/2/3.  ``b`` is an lmfit
    ``Parameters`` object or a sequence of eight values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    decay = vals[0]*exp(-vals[1]*x)
    first_peak = vals[2]*exp(-(x - vals[3])**2 / vals[4]**2)
    second_peak = vals[5]*exp(-(x - vals[6])**2 / vals[7]**2)
    # The two peaks are summed first, then added to the decay term.
    return decay + (first_peak + second_peak) - y


def Hahn1(b, x, y=0):
    """Residual of the NIST Hahn1 cubic/cubic rational model.

    ``b`` is an lmfit ``Parameters`` object or a sequence of seven values
    (see ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    numer = vals[0] + vals[1]*x + vals[2]*x**2 + vals[3]*x**3
    denom = 1 + vals[4]*x + vals[5]*x**2 + vals[6]*x**3
    return numer/denom - y


def Kirby(b, x, y=0):
    """Residual of the NIST Kirby2 quadratic/quadratic rational model.

    ``b`` is an lmfit ``Parameters`` object or a sequence of five values
    (see ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    numer = vals[0] + vals[1]*x + vals[2]*x**2
    denom = 1 + vals[3]*x + vals[4]*x**2
    return numer/denom - y


def Lanczos(b, x, y=0):
    """Residual of the NIST Lanczos model: a sum of three decaying exponentials.

    The model is ``b1*exp(-b2*x) + b3*exp(-b4*x) + b5*exp(-b6*x)``.  Shared
    by Lanczos1/2/3.  ``b`` is an lmfit ``Parameters`` object or a sequence
    of six values (see ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    model = vals[0]*exp(-vals[1]*x)
    for amplitude, rate in ((vals[2], vals[3]), (vals[4], vals[5])):
        model = model + amplitude*exp(-rate*x)
    return model - y


def MGH09(b, x, y=0):
    """Residual of the NIST MGH09 model ``b1*(x**2 + x*b2) / (x**2 + x*b3 + b4)``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    x_squared = x**2
    numer = x_squared + x*vals[1]
    denom = x_squared + x*vals[2] + vals[3]
    return vals[0]*numer/denom - y


def MGH10(b, x, y=0):
    """Residual of the NIST MGH10 model ``b1 * exp(b2 / (x + b3))``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    exponent = vals[1]/(x + vals[2])
    return vals[0]*exp(exponent) - y


def MGH17(b, x, y=0):
    """Residual of the NIST MGH17 model: offset plus two exponentials.

    The model is ``b1 + b2*exp(-x*b4) + b3*exp(-x*b5)``.  ``b`` is an lmfit
    ``Parameters`` object or a sequence of five values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    offset, amp1, amp2 = vals[0], vals[1], vals[2]
    rate1, rate2 = vals[3], vals[4]
    return offset + amp1*exp(-x*rate1) + amp2*exp(-x*rate2) - y


def Misra1a(b, x, y=0):
    """Residual of the NIST Misra1a model ``b1 * (1 - exp(-b2*x))``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    saturation = 1 - exp(-vals[1]*x)
    return vals[0]*saturation - y


def Misra1b(b, x, y=0):
    """Residual of the NIST Misra1b model ``b1 * (1 - (1 + b2*x/2)**(-2))``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    base = 1 + vals[1]*x/2
    return vals[0] * (1 - base**(-2)) - y


def Misra1c(b, x, y=0):
    """Residual of the NIST Misra1c model ``b1 * (1 - (1 + 2*b2*x)**(-1/2))``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    base = 1 + 2*vals[1]*x
    return vals[0] * (1 - base**(-.5)) - y


def Misra1d(b, x, y=0):
    """Residual of the NIST Misra1d model ``b1*b2*x * (1 + b2*x)**(-1)``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    scaled = vals[0]*vals[1]*x
    return scaled*((1 + vals[1]*x)**(-1)) - y


def Nelson(b, x, y=None):
    """Residual of the NIST Nelson model, which is fitted on a log scale.

    The model predicts ``log(y) = b1 - b2*x1*exp(-b3*x2)``, where the two
    predictors are the columns ``x[:, 0]`` and ``x[:, 1]`` of ``x``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of three values
    (see ``read_params``).  Returns ``model - log(y)``.  When ``y`` is None
    (the default), the predicted ``log(y)`` is returned.  Previously the
    default argument made ``log(None)`` raise, so the function could not be
    called without ``y``.
    """
    b = read_params(b)
    x1 = x[:, 0]
    x2 = x[:, 1]
    model = b[0] - b[1]*x1 * exp(-b[2]*x2)
    if y is None:
        return model
    return model - log(y)


def Rat42(b, x, y=0):
    """Residual of the NIST Rat42 logistic model ``b1 / (1 + exp(b2 - b3*x))``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    denom = 1 + exp(vals[1] - vals[2]*x)
    return vals[0]/denom - y


def Rat43(b, x, y=0):
    """Residual of the NIST Rat43 model ``b1 / (1 + exp(b2 - b3*x))**(1/b4)``.

    ``b`` is an lmfit ``Parameters`` object or a sequence of four values
    (see ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    base = 1 + exp(vals[1] - vals[2]*x)
    return vals[0]/(base**(1/vals[3])) - y


def Roszman1(b, x, y=0):
    """Residual of the NIST Roszman1 model.

    The model is ``b1 - b2*x - arctan(b3/(x - b4))/pi``.  ``b`` is an lmfit
    ``Parameters`` object or a sequence of four values (see
    ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    pi = 3.141592653589793238462643383279
    linear = vals[0] - vals[1]*x
    return linear - arctan(vals[2]/(x - vals[3]))/pi - y


def Thurber(b, x, y=0):
    """Residual of the NIST Thurber cubic/cubic rational model.

    ``b`` is an lmfit ``Parameters`` object or a sequence of seven values
    (see ``read_params``).  Returns ``model - y``.
    """
    vals = read_params(b)
    numer = vals[0] + vals[1]*x + vals[2]*x**2 + vals[3]*x**3
    denom = 1 + vals[4]*x + vals[5]*x**2 + vals[6]*x**3
    return numer/denom - y


# Registry of NIST StRD datasets, keyed by dataset name.  Each value is
# (residual function, number of fitted parameters, number of predictor
# columns in x).  Several datasets share one model function, for example
# Chwirut1/Chwirut2, Gauss1-3 and Lanczos1-3.
Models = dict(
    Bennett5=(Bennet5, 3, 1),
    BoxBOD=(BoxBOD, 2, 1),
    Chwirut1=(Chwirut, 3, 1),
    Chwirut2=(Chwirut, 3, 1),
    DanWood=(DanWood, 2, 1),
    ENSO=(ENSO, 9, 1),
    Eckerle4=(Eckerle4, 3, 1),
    Gauss1=(Gauss, 8, 1),
    Gauss2=(Gauss, 8, 1),
    Gauss3=(Gauss, 8, 1),
    Hahn1=(Hahn1, 7, 1),
    Kirby2=(Kirby, 5, 1),
    Lanczos1=(Lanczos, 6, 1),
    Lanczos2=(Lanczos, 6, 1),
    Lanczos3=(Lanczos, 6, 1),
    MGH09=(MGH09, 4, 1),
    MGH10=(MGH10, 3, 1),
    MGH17=(MGH17, 5, 1),
    Misra1a=(Misra1a, 2, 1),
    Misra1b=(Misra1b, 2, 1),
    Misra1c=(Misra1c, 2, 1),
    Misra1d=(Misra1d, 2, 1),
    Nelson=(Nelson, 3, 2),
    Rat42=(Rat42, 3, 1),
    Rat43=(Rat43, 4, 1),
    Roszman1=(Roszman1, 4, 1),
    Thurber=(Thurber, 7, 1),
)


def ReadNistData(dataset, nist_dir=None):
    """Read a NIST StRD nonlinear-regression dataset file.

    NIST STRD data is in a simple, fixed format with line numbers being
    significant!  In terms of 0-based line indices of the file:

    * line 31 starts with the number of model parameters;
    * lines 40-57 hold one ``bN = start1 start2 certified stderr`` line per
      parameter, followed by ``Label: value`` summary statistics;
    * lines 60 onward hold the data, one observation per line, with the
      response ``y`` first and the predictor value(s) after it.  Blank data
      lines (e.g. trailing blank lines at end of file) are ignored.

    Parameters
    ----------
    dataset : str
        Dataset name; the file ``<nist_dir>/<dataset>.dat`` is read.
    nist_dir : str, optional
        Directory holding the ``.dat`` files.  Defaults to ``NIST_DIR``.

    Returns
    -------
    dict
        ``y`` and ``x`` as numpy arrays.  ``x`` holds a list of predictors
        per row when observations have more than one predictor.  Also
        contains ``nparams``, ``ndata``, ``nfree``, ``start1``, ``start2``,
        ``sum_squares``, ``std_dev``, the certified values under both
        ``cert`` and ``cert_values``, and their standard errors under
        ``cert_stderr``.  Summary statistics absent from the file are None.
    """
    if nist_dir is None:
        nist_dir = NIST_DIR
    path = os.path.join(nist_dir, f"{dataset}.dat")
    with open(path, encoding='utf-8', errors='replace') as finp:
        # Strip only the line terminator: slicing off the last character
        # would corrupt a final line that has no trailing newline.
        lines = [line.rstrip('\r\n') for line in finp]
    ModelLines = lines[30:39]
    ParamLines = lines[40:58]
    DataLines = lines[60:]

    nparams = int(ModelLines[1].split()[0])

    start1 = [0]*nparams
    start2 = [0]*nparams
    certval = [0]*nparams
    certerr = [0]*nparams
    for i, text in enumerate(ParamLines[:nparams]):
        # Everything after '=' is: start1 start2 certified_value std_error
        s1, s2, val, err = [float(w) for w in text.split('=')[1].split()]
        start1[i] = s1
        start2[i] = s2
        certval[i] = val
        certerr[i] = err

    # Initialise so a file missing one of these labels yields None rather
    # than an UnboundLocalError at the return statement.
    sum_squares = std_dev = nfree = ndata = None
    for t in ParamLines[nparams:]:
        t = t.strip()
        if ':' not in t:
            continue
        val = float(t.split(':')[1])
        if t.startswith('Residual Sum of Squares'):
            sum_squares = val
        elif t.startswith('Residual Standard Deviation'):
            std_dev = val
        elif t.startswith('Degrees of Freedom'):
            nfree = int(val)
        elif t.startswith('Number of Observations'):
            ndata = int(val)

    y, x = [], []
    for d in DataLines:
        vals = [float(w) for w in d.split()]
        if not vals:
            continue
        y.append(vals[0])
        # Multi-predictor rows (e.g. Nelson) keep all predictors as a list.
        x.append(vals[1:] if len(vals) > 2 else vals[1])

    y = array(y)
    x = array(x)
    out = {'y': y, 'x': x, 'nparams': nparams, 'ndata': ndata,
           'nfree': nfree, 'start1': start1, 'start2': start2,
           'sum_squares': sum_squares, 'std_dev': std_dev,
           'cert': certval, 'cert_values': certval, 'cert_stderr': certerr}
    return out
